
import logging
import argparse  #命令行参数解析包
import math
import os
import sys
import platform


from utils.option_parameter_config import input_colses
from utils.option_parameter_config import dataset_files


from sklearn import metrics
from time import strftime, localtime  #strftime用于格式化字符串得到时间

from transformers import BertModel  #Bert模型

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split   #pytorch中的数据加载器，随机分块

from model.Bert_Spc import BERT_SPC
from model.ian import IAN
from utils import data_utils
from utils.option_parameter_config import Construct_config
from utils.data_utils import Tokenizer4Bert
from utils.time_utils import Time_utils
from utils.record import RecordHyperParameter
from utils.GPU import getGPU
#for no-bert model
from utils.data_utils_noBert import build_tokenizer
from utils.data_utils_noBert import build_embedding_matrix


import utils

from tqdm import tqdm  # 如果手动停止，就会一直刷新而不在同一行


from transformers import BertTokenizer

# 设置日志记录，用于输出到控制台
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.addHandler(logging.StreamHandler(sys.stdout))  # 标准流Handler，将消息发送到标准输出流、错误流





class Instructor:  # 初始化结束之后开始构建，构建完之后就开始run
    def __init__(self, opt):  # opt是配置文件，传入配置的参数，在main中有体现
        self.opt = opt  # 赋值给内部
        tokenizers=None
        bert=None
        self.model=None
        self.recordname = 'record'

        # 配置日志，记录模型参数和训练过程
        log_dir = "./log_dir"
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
        log_file = '.\\log_dir\\{}-{}-{}.log'.format(opt.model_name, opt.dataset,
                                                     strftime("%y%m%d-%H%M", localtime()))  # 日志文件的名字
        filehandler = logger.addHandler(logging.FileHandler(log_file))  # 将log内容存放在文件中

        self.model_processing_time = Time_utils(logger, "model_processing_time")
        self.data_processing_time = Time_utils(logger, "data_processing_time")  # 记录时间用

        if platform.system().lower() == 'windows':
            BERT_PATH = 'D:/thesis/code/dataset/bert/bert-base-uncased'  # 如果是服务器跑，就用bert-base-uncased
        elif platform.system().lower() == 'linux':
            BERT_PATH = 'bert-base-uncased'  # 如果是服务器跑，就用bert-base-uncased



        #判断是不是bert模型
        if 'bert' in opt.model_name:
            tokenizer = Tokenizer4Bert(opt.max_seq_len, BERT_PATH)# 采用bert分词器，默认类型是bert-base-uncased
            bert = BertModel.from_pretrained(BERT_PATH, return_dict=False)  # 从bert_path引入
            self.model = opt.model_class(bert, opt).to(opt.device)  # 将网络模型实例化之后转移到GPU上
        else:
            '''
                如果不是bert模型，处理数据的词嵌入，矩阵都需要改变
                传入.dat文件，如果存在的话就直接读取，否则需要重新加载
            '''
            tokenizer = utils.data_utils_noBert.build_tokenizer(  # data_utils中的构建tokenizer方法
                fnames=[opt.dataset_file['train'], opt.dataset_file['test']],
                max_seq_len=opt.max_seq_len,  # 最大序列长度
                dat_fname='{0}_tokenizer.dat'.format(opt.dataset))   # 综合了训练集和测试集的词语，并建立起了对应的word2index，负责转换字符串为index
            embedding_matrix = utils.data_utils_noBert.build_embedding_matrix(   # 如果加载过一次的话就不用再重新生成了loading embedding_matrix: 300_laptop_embedding_matrix.dat
                word2idx=tokenizer.word2idx,   # 字符转换为index的字典
                embed_dim=opt.embed_dim,   # 嵌入层的大小
                dat_fname='{0}_{1}_embedding_matrix.dat'.format(str(opt.embed_dim), opt.dataset))  # 根据tokenizer得到的word2index index2vec
            self.model = opt.model_class(embedding_matrix, opt).to(opt.device)  # opt.model_class 是具体的class，传入参数用于初始化

        # 给出了数据集的地址和分词器，data_utils用于自定义创建自定义数据集
        self.trainset = utils.data_utils.ABSADataset(opt.dataset_file['train'], tokenizer)  # 创建训练数据集，分词器根据模型名字给出
        self.testset = utils.data_utils.ABSADataset(opt.dataset_file['test'], tokenizer)
        self.data_processing_time.compute_time()

        assert 0 <= opt.valset_ratio < 1  # 验证集比例判断，用于交叉验证,验证集也用于判断何时停止
        if opt.valset_ratio > 0:
            valset_len = int(len(self.trainset) * opt.valset_ratio)  # 取一部分训练集出来作为验证集
            self.trainset, self.valset = random_split(self.trainset, (len(self.trainset) - valset_len, valset_len))  # dataset.random_split
        else:  # 否则用测试集作为验证集
            self.valset = self.testset

        if opt.device.type == 'cuda':
            logger.info('cuda memory allocated: {}'.format(torch.cuda.memory_allocated(device=opt.device.index)))
        self._print_args()  # 打印参数

    def _print_args(self):
        n_trainable_params, n_nontrainable_params = 0, 0
        for p in self.model.parameters():
            n_params = torch.prod(torch.tensor(p.shape))  # 计算参数的个数
            if p.requires_grad:  # 如果是能够训练的参数
                n_trainable_params += n_params
            else:
                n_nontrainable_params += n_params
        logger.info(
            '> n_trainable_params: {0}, n_nontrainable_params: {1}'.format(n_trainable_params, n_nontrainable_params))
        logger.info('> training arguments:')
        for arg in vars(self.opt):  # 拿到opt对象中的属性11
            logger.info('>>> {0}: {1}'.format(arg, getattr(self.opt, arg)))

    def _reset_params(self):
        for child in self.model.children():
            if type(child) != BertModel:  # skip bert params
                for p in child.parameters():
                    if p.requires_grad:
                        if len(p.shape) > 1:
                            self.opt.initializer(p)   #初始化参数
                        else:
                            stdv = 1. / math.sqrt(p.shape[0])
                            torch.nn.init.uniform_(p, a=-stdv, b=stdv)


    def _train(self, criterion, optimizer, train_data_loader, val_data_loader):  # （损失函数，优化器，训练集，评估集）,保护方法
        max_val_acc = 0
        max_val_f1 = 0
        max_val_epoch = 0
        global_step = 0  # 初始化评估参数
        path = None
        loops = 6

        for loop in range(loops):
            print('-'*50,'this loop\'s total epoch：{}'.format(loop),'-'*50)
            print('-'*50,'using dataset:{}'.format(opt.dataset),'-'*50)
            self.opt.num_epoch += 5 # 5，10，15这样递增，直到30
            if not os.path.exists(self.recordName+'(epoch='+str(self.opt.num_epoch)+')'+'.xlsx'):  #如果不存在这个文件名
                self.recordName = '(epoch='+str(self.opt.num_epoch)+')'+'.xlsx'
            for i_epoch in tqdm(range(self.opt.num_epoch),desc='outer loop'):  # 训练多少个epoch
                logger.info('>' * 100)  # 打印>号
                logger.info('total epoch: {},epoch: {}'.format(self.opt.num_epoch,i_epoch))  # 第几个epoch
                n_correct, n_total, loss_total = 0, 0, 0  # 正确，错误数和总数量
                # switch model to training mode
                self.model.train()  # 网络模型已经在其他class中搭好了，执行的是对应模型里面的train，转换模型为训练模式
                cnt = 0
                for i_batch, batch in enumerate(train_data_loader):  # 对于训练集里面的每一个batch和batch中的第i个数据，len得到的是一共有多少个batch
                    '''
                        有batch就是批梯度下降算法
                        i_batch是int应该是batch里面第几条数据的意思，一个batch就有16条数据
                        batch就是打包好的16条数据
                        数据量/batchsize=batch的数量1
    
                        训练一次跑了145个batch
                    '''
                    global_step += 1
                    # clear gradient accumulators
                    optimizer.zero_grad()  # 每一个batch训练结束就清除一次梯度，清除叶子节点的累加梯度？

                    # 输入/输出的形式和目标，目标就是情感的极性，输入就是ABSA1Dataset建立的字典类型
                    inputs = [batch[col].to(self.opt.device) for col in self.opt.inputs_cols]  # 根据input_cols列表获取输入
                    outputs = self.model(inputs)  # 对于输入的数据输入到模型进行计算，得到的输出，输入是16条数据进行训练
                    targets = batch['polarity'].to(self.opt.device)  # 目标极性

                    loss = criterion(outputs, targets)  # 损失函数的定义，在run的时候给出。给损失函数取别名，line 168，用的是原有的交叉熵函数，输入输出进行对比，在batch中做损失，得到的是一个标量
                    loss.backward()   # 损失函数计算完成之后，得到的梯度值存在张量中
                    optimizer.step()  # adam优化器才是真正执行梯度下降更行参数的，执行单个优化步骤（参数更新），默认adam

                    n_correct += (torch.argmax(outputs, -1) == targets).sum().item()  # 如果预测结果正确
                    n_total += len(outputs)
                    loss_total += loss.item() * len(outputs)  # 整个epoch的loss
                    if global_step % self.opt.log_step == 0:  # 当训练执行了对10条input的预测之后就输出一次当前的loss
                        train_acc = n_correct / n_total
                        train_loss = loss_total / n_total  # 进行平均
                        logger.info('loss: {:.4f}, acc: {:.4f}'.format(train_loss, train_acc))

                # 跑完一个epoch之后的处理，首先跑验证集验证性能
                val_acc, val_f1 = self._evaluate_acc_f1(val_data_loader)  # 载入验证集进行测试
                logger.info('> val_acc: {:.4f}, val_f1: {:.4f}'.format(val_acc, val_f1))
                if val_acc > max_val_acc:  # 输出效果最好的一次
                    max_val_acc = val_acc  # 更新最高的acc
                    max_val_epoch = i_epoch  # 更新在第几个epoch出现最好乘积
                    if not os.path.exists('state_dict'):  # 如果不存在dict文件夹，就创建dict文件夹,保存模型参数
                        os.mkdir('state_dict')
                    path = 'state_dict/{0}_{1}_val_acc_{2}'.format(self.opt.model_name, self.opt.dataset, round(val_acc, 4)) # 将要保存的模型文件的名称
                    torch.save(self.model.state_dict(), path)  # 保存模型
                    logger.info('>> saved: {}'.format(path))  # 输出保存模型的信息
                if val_f1 > max_val_f1:  # 更新f1值
                    max_val_f1 = val_f1
                if i_epoch - max_val_epoch >= self.opt.patience:  # 提前停止，防止过拟合,当前的epoch和得到最大效果的epoch的距离大于patience时，停止训练，默认为5
                    print('>> early stop.')
                    break

            self.model_processing_time.compute_time()
        return path

    def _evaluate_acc_f1(self, data_loader):  # 评估acc,f1
        n_correct, n_total = 0, 0
        t_targets_all, t_outputs_all = None, None
        # switch model to evaluation mode
        self.model.eval()
        with torch.no_grad():  # 固定住训练好的参数不进行梯度计算，只进行前向传播
            for i_batch, t_batch in enumerate(data_loader):
                t_inputs = [t_batch[col].to(self.opt.device) for col in self.opt.inputs_cols]
                t_targets = t_batch['polarity'].to(self.opt.device)
                t_outputs = self.model(t_inputs)

                n_correct += (torch.argmax(t_outputs, -1) == t_targets).sum().item()
                n_total += len(t_outputs)

                if t_targets_all is None:
                    t_targets_all = t_targets
                    t_outputs_all = t_outputs
                else:
                    t_targets_all = torch.cat((t_targets_all, t_targets), dim=0)
                    t_outputs_all = torch.cat((t_outputs_all, t_outputs), dim=0)

        acc = n_correct / n_total
        f1 = metrics.f1_score(t_targets_all.cpu(), torch.argmax(t_outputs_all, -1).cpu(), labels=[0, 1, 2],
                              average='macro')  # 这里用上了metrics，评价macro-F1
        return acc, f1

    def run(self):  # 模型训练
        # Loss and Optimizer
        criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数，体现了正则化吗？
        _params = filter(lambda p: p.requires_grad, self.model.parameters())  # 过滤参数，剩下需要进行参数计算的，然后将参数网络中所有需要训练的参数放在_params中
        optimizer = self.opt.optimizer(_params, lr=self.opt.lr, weight_decay=self.opt.l2reg)  # 将所有需要求导计算的参数传入优化器中。如果只使用L2正则化，那么也可以利用优化器的weight_decay参数来实现。

        # 载入数据
        train_data_loader = DataLoader(dataset=self.trainset, batch_size=self.opt.batch_size, shuffle=True)  #直接引用原来的dataloader
        test_data_loader = DataLoader(dataset=self.testset, batch_size=self.opt.batch_size, shuffle=False)
        val_data_loader = DataLoader(dataset=self.valset, batch_size=self.opt.batch_size, shuffle=False)

        self._reset_params()

        best_model_path = self._train(criterion, optimizer, train_data_loader,
                                      val_data_loader)  # 开始训练(损失函数，优化器，训练集)，保护方法

        self.model.load_state_dict(torch.load(best_model_path)) # 训练完之后加载测试集表现最好的模型，测试其表现
        test_acc, test_f1 = self._evaluate_acc_f1(test_data_loader)  # 训练好模型之后在测试集上进行性能评估
        logger.info('>> test_acc: {:.4f}, test_f1: {:.4f}'.format(test_acc, test_f1))  # 打印输出

        result = {}
        result['acc'] = test_acc
        result['f1'] = test_f1
        result['time'] = self.model_processing_time.compute_time()
        recorder = RecordHyperParameter(result, opt, self.recordName)
        recorder.write()


models_name = utils.option_parameter_config.model_name  # 一次训练多个数据集
for model in models_name:
    for dataset in dataset_files:
        opt = Construct_config()  # 默认生成的输入，模型名等都确定了
        opt.dataset = dataset  # 使用不同的数据集
        opt.model_name = model
        opt.model_class = utils.option_parameter_config.model_classes[opt.model_name]
        opt.dataset_file = dataset_files[opt.dataset]  # 拿到数据集的文件
        opt.inputs_cols = input_colses[opt.model_name]  # 输入数据的形式，对于不同的网络模型输入的形式不同
        print('*'*50,"model {} is traing".format(opt.model_name),'*'*50)
        ins = Instructor(opt)  # 每次都根据新的opt选择新的输入
        ins.run()